import numpy as np
import copy

# 4x3 ASCII glyph for a single crossing, copied row by row into the picture
# drawn by Braid.printBraid.
crossing_box = [list("   "), list("\\ /"), list(" \\ "), list("/ \\")]

class Braid:
  """A braid on `n` strands given as a word in the Artin generators.

  `word` is a list of nonzero ints: i stands for sigma_i and -i for
  sigma_i^-1. Crossing numbers passed to the move methods are 1-based
  positions in `word`.
  """

  def __init__(self, n, word):
    self.n = n
    # Own a copy: mirror() edits the word in place and must not silently
    # modify the list the caller passed in.
    self.word = list(word)

  def __eq__(self,other):
    if not isinstance(other, Braid):
      return NotImplemented
    return (self.n==other.n and self.word==other.word)

  def __repr__(self):
    return str(self.word)

  def printBraid(self):
    """Print an ASCII picture of the braid, one 5-column box per crossing."""
    if (self.n==0):
      print(" ")
      return
    res = [[" " for j in range(len(self.word)*5)] for i in range(3*self.n-2)]
    for i in range(len(self.word)):
      # every strand row gets a horizontal line through this crossing's box
      for j in range(0,3*self.n-2,3):
        for k in range(5*i,5*i+5):
          res[j][k] = '_'
      # top row of the 4x3 crossing glyph; rows are numbered top-down, so a
      # larger |generator| is drawn higher up
      strand = 3*(self.n-1-abs(self.word[i]))
      for j in range(4):
        for k in range(3):
          res[strand+j][5*i+1+k] = crossing_box[j][k]
      # inverse generators flip the middle character so the '/' strand is on top
      if self.word[i]<0:
        res[strand+2][5*i+2] = '/'
    for line in res:
      print(''.join(line))

  def countCrossings(self):
    """Return (positive, negative) crossing counts.

    Any entry that is not positive is counted as negative.
    """
    npos = sum(1 for cross in self.word if cross>0)
    return npos, len(self.word)-npos

  def mirror(self):
    """Mirror the braid in place: reverse the word and invert every generator."""
    self.word[:] = [-cross for cross in reversed(self.word)]

  def saddleMove(self,crossing):
    """Remove the `crossing`-th (1-based) letter of the word.

    This is the braid-word effect of a saddle move across that crossing
    followed by a Reidemeister I move.
    """
    self.word = self.word[0:crossing-1]+self.word[crossing:]

  def reidemeister2(self,crossing):
    """Cancel letters `crossing` and `crossing`+1 (1-based) if they are inverse.

    Prints "invalid Reidemeister II move" and leaves the word unchanged when
    the two letters are not inverse or the position is out of range.
    """
    if not (1 <= crossing < len(self.word)) or self.word[crossing-1]!=-self.word[crossing]:
      print("invalid Reidemeister II move")
      return
    self.word = self.word[0:crossing-1]+self.word[crossing+1:]

  def circularshift(self,s):
    """Shift the crossings circularly by `s` (not implemented; currently a no-op)."""
    #TODO
    return

# Module-level counter that supplies consecutive default `data` values to new
# Nodes (Smoothing resets it to 0 before numbering a diagram).
idx = 0
class Node:
  """A node in a doubly linked structure, labeled by the integer `data`.

  When `data` is omitted, the current value of the global counter `idx` is
  used and the counter is incremented.
  """
  def __init__(self,prev=None,next=None,data=None):
    global idx
    self.prev = prev
    self.next = next
    if data is None:
      self.data = idx
      idx += 1
    else:
      self.data = data

  def __eq__(self, other):
    # Nodes compare by label; comparisons with non-Nodes (e.g. None) defer to
    # Python's default instead of raising AttributeError.
    if not isinstance(other, Node):
      return NotImplemented
    return self.data==other.data

  def __repr__(self):
    return str(self.data)

class Graph(object):
  """Thin container for a list of Node objects (None until one is supplied)."""

  def __init__(self, nodes=None):
    # the list is stored as given, not copied
    self.nodes = nodes

class Smoothing:
  """A resolution of every crossing of a closed braid.

  Attributes:
    braid   -- deep copy of the braid passed to the constructor.
    smooth  -- one '0'/'1' character per entry of braid.word.
    graph   -- Graph of all Nodes; crossing i owns nodes 4*i .. 4*i+3.
    circles -- list of circles, each a list of Nodes. They are deep copies of
               the graph's nodes: equal by `data`, but different objects.
  """

  def __init__(self, braid, smooth):
    global idx
    idx = 0
    self.braid = copy.deepcopy(braid)
    self.smooth = smooth
    self.graph = None
    self.circles = None
    self.setCircles()

  def setCircles(self):
    """Build the node graph for braid/smooth and trace it into circles.

    The global node counter is reset, so crossing i (0-based) on generator
    |w| gets nodes 4i, 4i+1 on row n-|w|-1 and 4i+2, 4i+3 on the row below.
    """
    global idx
    idx = 0
    rows = [[] for i in range(self.braid.n)]
    for i in range(len(self.braid.word)):
      row = self.braid.n-abs(self.braid.word[i])-1
      # two nodes on the upper row, chained after that row's previous node
      if (len(rows[row])==0):
        rows[row].append(Node())
      else:
        rows[row].append(Node(prev=rows[row][-1]))
        rows[row][-2].next = rows[row][-1]
      rows[row].append(Node())
      # two nodes on the lower row, chained the same way
      if (len(rows[row+1])==0):
        rows[row+1].append(Node())
      else:
        rows[row+1].append(Node(prev=rows[row+1][-1]))
        rows[row+1][-2].next = rows[row+1][-1]
      rows[row+1].append(Node())
      if (((self.braid.word[i]>0) and (self.smooth[i]=='0')) or ((self.braid.word[i]<0) and (self.smooth[i]=='1'))):
        # resolution that keeps each row: 4i<->4i+1 and 4i+2<->4i+3
        rows[row][-2].next = rows[row][-1]
        rows[row][-1].prev = rows[row][-2]
        rows[row+1][-2].next = rows[row+1][-1]
        rows[row+1][-1].prev = rows[row+1][-2]
      else:
        # resolution that joins the rows: 4i<->4i+2 and 4i+1<->4i+3
        rows[row][-2].next = rows[row+1][-2]
        rows[row][-1].prev = rows[row+1][-1]
        rows[row+1][-2].next = rows[row][-2]
        rows[row+1][-1].prev = rows[row][-1]
    graphnodes = []
    for row in rows:
      # close each row cyclically (braid closure)
      if (len(row)>0):
        row[0].prev = row[-1]
        row[-1].next = row[0]
      for nod in row:
        graphnodes.append(nod)
    self.graph = Graph(graphnodes)
    # trace circles on a deep copy, following next/prev until both are visited
    tmp = copy.deepcopy(graphnodes)
    self.circles = []
    while (len(tmp)!=0):
      newcirc = [tmp[0]]
      tmp.pop(0)
      while True:
        last = newcirc[-1]
        if (not last.next in newcirc):
          newcirc.append(last.next)
          tmp.remove(last.next)
        elif (not last.prev in newcirc):
          newcirc.append(last.prev)
          tmp.remove(last.prev)
        else:
          break
      self.circles.append(newcirc)

  def __eq__(self,other):
    if not isinstance(other, Smoothing):
      return NotImplemented
    return (self.braid==other.braid and self.smooth==other.smooth)

  def __repr__(self):
    return str(self.circles)

  def homgrading(self):
    """Return the homological grading: number of '1' resolutions minus the
    number of negative crossings."""
    _, nn = self.braid.countCrossings()
    return self.smooth.count('1')-nn

  def mirror(self):
    """Mirror in place: mirror the braid, reverse and complement `smooth`,
    and rebuild graph and circles from scratch."""
    self.braid.mirror()
    self.smooth = ''.join('1' if ch=='0' else '0' for ch in reversed(self.smooth))
    self.setCircles()

  def checkCrossings(self):
    """Return the sorted distinct crossing numbers (data//4) still in the graph."""
    return sorted({nod.data//4 for nod in self.graph.nodes})

  def saddleMove(self,crossing):
    """Saddle move (then Reidemeister I) at the `crossing`-th surviving crossing.

    `crossing` is 1-based among checkCrossings(). The crossing's four nodes
    are spliced out (the node before node 0 now continues to the node after
    node 1, and likewise for nodes 2/3) and removed from graph and circles;
    the crossing is also removed from braid.word and smooth.
    """
    crossnodes = [None,None,None,None]
    crossnumbers = self.checkCrossings()
    for nod in self.graph.nodes:
      if (nod.data//4==crossnumbers[crossing-1]):
        i = nod.data%4
        crossnodes[i] = nod
    crossnodes[0].prev.next = crossnodes[1].next
    crossnodes[1].next.prev = crossnodes[0].prev
    crossnodes[2].prev.next = crossnodes[3].next
    crossnodes[3].next.prev = crossnodes[2].prev
    for nod in crossnodes:
      self.graph.nodes.remove(nod)
      for circ in self.circles:
        if nod in circ:
          circ.remove(nod)
          break
    #removes the crossing from the braid
    self.braid.saddleMove(crossing)
    self.smooth = self.smooth[0:crossing-1]+self.smooth[crossing:]

  def reidemeister2(self,crossing):
    """Reidemeister II move cancelling crossings `crossing` and `crossing`+1.

    Both positions are 1-based among the surviving crossings and must hold
    mutually inverse generators; otherwise "invalid Reidemeister II move" is
    printed and nothing changes. The eight nodes are spliced out of the
    graph and circles, and both crossings are removed from braid.word and
    from smooth.
    """
    word = self.braid.word
    if not (1 <= crossing < len(word)) or word[crossing-1]!=-word[crossing]:
      print("invalid Reidemeister II move")
      return
    crossnumbers = self.checkCrossings()
    # crossnodes[0..3]: first crossing, [4..7]: second crossing (by data%4)
    crossnodes = [None,None,None,None,None,None,None,None]
    for nod in self.graph.nodes:
      if (nod.data//4==crossnumbers[crossing-1] or nod.data//4==crossnumbers[crossing]):
        if (nod.data//4==crossnumbers[crossing-1]):
          i = nod.data%4
        else:
          i = (nod.data%4)+4
        crossnodes[i] = nod
    crossnodes[0].prev.next = crossnodes[5].next
    crossnodes[5].next.prev = crossnodes[0].prev
    crossnodes[2].prev.next = crossnodes[7].next
    crossnodes[7].next.prev = crossnodes[2].prev
    for nod in crossnodes:
      self.graph.nodes.remove(nod)
      for circ in self.circles:
        if nod in circ:
          circ.remove(nod)
          break
    self.braid.reidemeister2(crossing)
    # keep smooth aligned with braid.word, exactly as saddleMove does
    self.smooth = self.smooth[0:crossing-1]+self.smooth[crossing+1:]

def mirrorNumber(x,c):
  """Return the number node `x` gets after mirroring a braid with `c` crossings.

  Node x sits in crossing x//4 at position x%4. Mirroring reverses the
  crossing order (k -> c-1-k) and swaps the nodes inside each pair
  (0<->1, 2<->3) while keeping the pair itself.
  """
  crossing, pos = divmod(x, 4)
  pair, side = divmod(pos, 2)
  return 4*(c - 1 - crossing) + 2*pair + (1 - side)

def flipLabel(l):
  """Swap the circle labels 'x' and '1'; any other value maps to '0'."""
  if l=='x':
    return '1'
  return 'x' if l=='1' else '0'

#labels are encoded as a string with one character per circle: '1' or 'x' (a '0' marks a zero state)
class LabeledState:
  """A Smoothing together with one label per circle.

  Labels are '1' or 'x'; a circle labeled '0' makes the whole state zero
  (see isZero). `labels` is the label string and `labeledcircles` is a list
  of [label, circle] pairs; fixLabels() rebuilds `labels` from the pairs.
  """

  def __init__(self,smoothing,labels):
    self.smoothing = smoothing
    self.labels = labels
    # pair the i-th label with the i-th circle of the smoothing
    self.labeledcircles = [[labels[i], circ] for i, circ in enumerate(smoothing.circles)]

  def __eq__(self,other):
    if not isinstance(other, LabeledState):
      return NotImplemented
    return (self.smoothing==other.smoothing and self.labels==other.labels)

  def isZero(self):
    """Return True if any circle carries the label '0'."""
    return any(circ[0]=='0' for circ in self.labeledcircles)

  def findLabel(self,idx):
    """Return the label of the circle containing the node whose data is `idx`,
    or '1' if no circle contains such a node."""
    for label, circ in self.labeledcircles:
      if any(nod.data==idx for nod in circ):
        return label
    return '1'

  def __repr__(self):
    """Draw the state as ASCII art: the closed braid, node numbers at each
    crossing, and circle labels next to the crossing nodes.

    NOTE(review): crossing i of braid.word is drawn with nodes 4*i..4*i+3;
    after moves that removed crossings the surviving nodes keep their old
    numbers, so the drawing may not match then -- confirm.
    """
    braididx = self.smoothing.braid.n
    ncrossings = len(self.smoothing.braid.word)
    idxlen = len(str(4*ncrossings))
    ndash = 3
    l = idxlen*2*ncrossings + (2*ncrossings-1)*ndash + 2*(braididx + 1)
    h = 4*braididx-1
    res = [[" " for j in range(l)] for i in range(h)]
    #closes circles on the left
    for i in range(braididx):
      for j in range(i+1,braididx+1):
        res[2*i][j] = '-'
        res[4*braididx-2-2*i][j] = '-'
      for j in range(2*i+1,4*braididx-2-2*i):
        res[j][i] = '|'
    #fills middle with horizontal dashes
    for i in range(braididx+1, l-braididx-1):
      for j in range(0,h,2):
        res[j][i] = '-'
    #closes circles on the right
    for i in range(braididx):
      for j in range(l-braididx-1,l-1-i):
        res[2*i][j] = '-'
        res[4*braididx-2-2*i][j] = '-'
      for j in range(2*i+1,4*braididx-2-2*i):
        res[j][l-i-1] = '|'
    #sorted list of node indices
    nodelist = sorted(nod.data for circ in self.labeledcircles for nod in circ[1])
    #goes through crossings and adds node labels, connects nodes, and adds circle labels
    for i in range(ncrossings):
      move = self.smoothing.braid.word[i]
      strand = abs(move)
      resol = self.smoothing.smooth[i]
      col = braididx+1+i*(2*idxlen+2*ndash)  # left edge of this crossing's box
      if ((resol=='0' and move<0) or (resol=='1' and move>0)):
        # resolution joining the two rows: erase the dashes between the node
        # pairs and draw vertical connectors with the circle labels
        for j in range(2):
          for k in range(ndash):
            res[2*(braididx-strand-j)][col+idxlen+k] = ' '
        res[2*(braididx-strand)-1][col] = '|'
        res[2*(braididx-strand)-1][col+1] = self.findLabel(4*i)
        res[2*(braididx-strand)-1][col+idxlen+ndash] = '|'
        res[2*(braididx-strand)-1][col+idxlen+ndash+1] = self.findLabel(4*i+1)
      else:
        res[2*(braididx-strand)-1][col+1] = self.findLabel(4*i)
        res[2*(braididx-strand)+1][col+1] = self.findLabel(4*i+2)
      for j in range(4):
        nod = 4*i+j
        if nod in nodelist:
          for k in range(idxlen):
            tmp = ' '
            if (k<len(str(nod))):
              tmp = str(nod)[k]
            if j<2:
              res[2*(braididx-strand-1)][col+(j%2)*(idxlen+ndash)+k] = tmp
            else:
              res[2*(braididx-strand)][col+(j%2)*(idxlen+ndash)+k] = tmp
    return ''.join(''.join(row) + '\n' for row in res)

  def getGrading(self):
    """Return (h, q): h is the smoothing's homological grading and
    q = (#'1' labels) - (#other labels) + h + n_pos - n_neg."""
    npos, nneg = self.smoothing.braid.countCrossings()
    h = self.smoothing.homgrading()
    vp = self.labels.count('1')
    vn = len(self.labels)-vp
    q = vp - vn + h + npos - nneg
    return h,q

  def mirror(self):
    """Mirror the state in place.

    The smoothing is mirrored (which rebuilds its circles); each new circle
    then receives the flipped label (flipLabel) of the old circle whose first
    node maps onto one of its nodes under mirrorNumber.
    NOTE(review): mirrorNumber uses the current crossing count, so this
    assumes node numbers are still 0..4c-1 -- confirm after removal moves.
    """
    #list containing pairs of [label,number of a node] for each labeled circle
    circlabels = []
    nulllabels = []
    for circ in self.labeledcircles:
      if (len(circ[1])>0):
        circlabels.append([circ[0],circ[1][0].data])
      else:
        nulllabels.append(circ[0])
    self.smoothing.mirror()
    newlabeledcircles = []
    count = 0
    for circ in self.smoothing.circles:
      if (len(circ)==0):
        newlabeledcircles.append([flipLabel(nulllabels[count]),circ])
        count += 1
      else:
        for nod in circ:
          found = False
          for label in circlabels:
            if nod.data==mirrorNumber(label[1],len(self.smoothing.braid.word)):
              newlabeledcircles.append([flipLabel(label[0]),circ])
              found = True
              break
          if found:
            break
    self.labeledcircles = newlabeledcircles
    self.fixLabels()

  def fixLabels(self):
    """Rebuild self.labels from the labels stored in self.labeledcircles."""
    self.labels = ''.join(circ[0] for circ in self.labeledcircles)

  def fixSmoothing(self):
    """Rebuild the smoothing's circles, graph and smooth string from labeledcircles.

    A surviving crossing gets '0' when its first node (data%4 == 0) is
    followed by node data%4 == 1 on a positive crossing, or by data%4 == 2 on
    a negative one, and '1' otherwise.
    """
    circles = [circ[1] for circ in self.labeledcircles]
    self.smoothing.circles = circles
    graphnodes = [nod for circ in circles for nod in circ]
    self.smoothing.graph = Graph(graphnodes)
    # Surviving crossings keep their original numbers (data//4), which become
    # sparse after saddle / Reidemeister II moves; map each to its position in
    # the shortened braid word instead of indexing the word with it directly.
    crossnumbers = sorted({nod.data//4 for nod in graphnodes})
    position = {cnum: pos for pos, cnum in enumerate(crossnumbers)}
    sm = ['1']*len(self.smoothing.braid.word)
    for nod in graphnodes:
      if (nod.data%4==0):
        pos = position[nod.data//4]
        nxt = nod.next.data
        cross = self.smoothing.braid.word[pos]
        if ((nxt%4==1 and cross>0) or (nxt%4==2 and cross<0)):
          sm[pos] = '0'
    self.smoothing.smooth = ''.join(sm)

def whichCircle(i):
  """Return 0 for the first two nodes of a crossing (i % 4 in {0, 1}), else 1."""
  return 0 if i % 4 <= 1 else 1

#an element of the Khovanov chain complex, represented as a list of labeled states and a corresponding list of coefficients
#labeled states can be 0 and the coefficients can be 0, labeled states do not have to be unique
class Ccelement:
  def __init__(self,states,coefficients):
    """Store the terms: states[i] carries coefficient coefficients[i]."""
    self.states, self.coefficients = states, coefficients

  #returns a new Ccelement where the states are unique and 0 elements are removed
  def simplify(self):
    """Return an equivalent Ccelement whose terms are distinct and nonzero.

    Terms whose state isZero() or whose coefficient is 0 are dropped, equal
    states are merged by adding their coefficients, and merged terms whose
    coefficients cancel to 0 are dropped as well. An element with no states
    (the scalar form produced by deaths()) is returned as-is (same object).
    """
    if len(self.states)==0:
      return self
    resstates = []
    rescoeffs = []
    for i, st in enumerate(self.states):
      if st.isZero() or self.coefficients[i]==0:
        continue
      # states are only comparable with ==, hence the linear search
      if st in resstates:
        rescoeffs[resstates.index(st)] += self.coefficients[i]
      else:
        resstates.append(st)
        rescoeffs.append(self.coefficients[i])
    keep = [j for j, c in enumerate(rescoeffs) if c != 0]
    return Ccelement([resstates[j] for j in keep], [rescoeffs[j] for j in keep])

  def __eq__(self,other):
    """Compare two elements as formal sums.

    Both sides are simplified and terms with coefficient 0 are ignored; the
    remaining (state, coefficient) terms must match one-to-one. Elements
    without states are scalars and are compared by value (the single entry
    of `coefficients`, as __repr__ prints it, otherwise 0).
    """
    if not isinstance(other, Ccelement):
      return NotImplemented

    def terms(el):
      s = el.simplify()
      return [(st, s.coefficients[i]) for i, st in enumerate(s.states) if s.coefficients[i] != 0]

    def scalar(el):
      if len(el.states)==0 and len(el.coefficients)==1:
        return el.coefficients[0]
      return 0

    terms1 = terms(self)
    terms2 = terms(other)
    if len(terms1)!=len(terms2):
      return False
    for st1, c1 in terms1:
      if not any(st1==st2 and c1==c2 for st2, c2 in terms2):
        return False
    return scalar(self)==scalar(other)

  def __repr__(self):
    """Render each term as its coefficient, '*', a newline and the state,
    with terms separated by a newline and ' + '. An element without states
    prints its single coefficient, or '0'."""
    if not self.states:
      return str(self.coefficients[0]) if len(self.coefficients)==1 else '0'
    pieces = [str(self.coefficients[i]) + "*\n" + str(state) for i, state in enumerate(self.states)]
    return '\n + '.join(pieces)

  def isZero(self):
    """Return True when every coefficient left after simplify() is 0."""
    return all(c == 0 for c in self.simplify().coefficients)

  def getGrading(self):
    if len(self.states)==0:
      return 0,0
    return self.states[0].getGrading()

  def mirror(self):
    for st in self.states:
      st.mirror()

  #does a saddle move across a particular crossing for all terms in the sum
  def saddleMove(self,crossing):
    """Apply a saddle move at one crossing to every term, in place.

    `crossing` is 1-based among the crossings still present (it indexes the
    sorted list from Smoothing.checkCrossings()). Each state is mutated: the
    paths through the crossing's four nodes are spliced out of its circles,
    the crossing is removed from its smoothing, and the state contributes:

    * non-braided resolution (crossnodes[0].next == crossnodes[2]): the
      affected circles are re-traced and labeled '0', so the term is zero
      and is dropped by simplify();
    * braided resolution, positive crossing: labels are unchanged;
    * braided resolution, negative crossing: labels and coefficients change
      as coded below (terms with coefficient 0 are dropped by simplify()).

    The collected terms are combined with simplify().
    """
    resstates = []
    rescoeffs = []
    for k in range(len(self.states)):
      st = self.states[k]
      tmpstates = []
      tmpcoeffs = []
      crossnodes = [None,None,None,None]
      crossnumbers = st.smoothing.checkCrossings()
      for nod in st.smoothing.graph.nodes:
        if (nod.data//4==crossnumbers[crossing-1]):
          i = nod.data%4
          crossnodes[i] = nod
      #importantcircles has one or two labeled circles, in st.labeledcircles order.
      #NOTE(review): the code below treats importantcircles[0] as the circle through
      #the upper-left node; that depends on circle order -- confirm.
      importantcircles = []
      importantidx = [None,None,None,None]
      circno = 0
      for circ in st.labeledcircles:
        found = False
        for i in range(len(crossnodes)):
          if crossnodes[i] in circ[1]:
            if not circ in importantcircles:
              importantcircles.append(circ)
            importantidx[i] = [circno,circ[1].index(crossnodes[i])]
            found = True
        if found:
          circno += 1
      # the circles' own copies of nodes 0/1 (upper) and 2/3 (lower)
      importantnodes = [[None,None],[None,None]]
      for i in range(2):
        for j in range(2):
          importantnodes[i][j] = importantcircles[importantidx[2*i+j][0]][1][importantidx[2*i+j][1]]
      importantnodes[0][0].prev.next = importantnodes[0][1].next
      importantnodes[0][1].next.prev = importantnodes[0][0].prev
      importantnodes[1][0].prev.next = importantnodes[1][1].next
      importantnodes[1][1].next.prev = importantnodes[1][0].prev
      tmp = copy.deepcopy(importantcircles)
      #case where crossing has non-braided smoothing: goes to 0
      if crossnodes[0].next==crossnodes[2]:
        newcircles = []
        for nod in crossnodes:
          for circ in tmp:
            if nod in circ[1]:
              circ[1].remove(nod)
        #splitting one circle
        if len(tmp)==1:
          for i in range(2):
            if len(tmp[0][1])==0:
              newcirc = []
              newcircles.append(newcirc)
            else:
              newcirc = [tmp[0][1][0]]
              tmp[0][1].pop(0)
              while True:
                last = newcirc[-1]
                if (not last.next in newcirc):
                  newcirc.append(last.next)
                  tmp[0][1].remove(last.next)
                elif (not last.prev in newcirc):
                  newcirc.append(last.prev)
                  tmp[0][1].remove(last.prev)
                else:
                  break
              newcircles.append(newcirc)
        #merging two distinct circles
        else:
          tmp = tmp[0][1]+tmp[1][1]
          newcirc = [tmp[0]]
          tmp.pop(0)
          while True:
            last = newcirc[-1]
            if (not last.next in newcirc):
              newcirc.append(last.next)
              tmp.remove(last.next)
            elif (not last.prev in newcirc):
              newcirc.append(last.prev)
              tmp.remove(last.prev)
            else:
              break
          newcircles.append(newcirc)
        #replace old circles with new ones labeled '0' (a zero state) and add the result
        newlabeledcircles = [['0',circ] for circ in newcircles]
        for circ in importantcircles:
          st.labeledcircles.remove(circ)
        for circ in newlabeledcircles:
          st.labeledcircles.append(circ)
        st.fixLabels()
        st.smoothing.saddleMove(crossing)
        tmpstates.append(st)
        tmpcoeffs.append(self.coefficients[k])
      #case where crossing has braided smoothing
      else:
        for nod in crossnodes:
          for circ in importantcircles:
            if nod in circ[1]:
              circ[1].remove(nod)
        #if positive crossing, labels remain the same. if negative crossing, change labels based on previous labels of circle(s)
        #importantcircles entries are st.labeledcircles' own [label, circle] lists, so relabeling them relabels st
        if st.smoothing.braid.word[crossing-1]>0:
          st.fixLabels()
          st.smoothing.saddleMove(crossing)
          tmpstates.append(st)
          tmpcoeffs.append(self.coefficients[k])
        else:
          top = importantcircles[0][0]
          bottom = importantcircles[-1][0]
          if (top=='1'):
            if (bottom=='1' and len(importantcircles)==2):
              importantcircles[1][0] = 'x'
              st.fixLabels()
              st.smoothing.saddleMove(crossing)
              # snapshot the ('1','x') term before relabeling st as ('x','1')
              st1 = copy.deepcopy(st)
              tmpstates.append(st1)
              tmpcoeffs.append(-1*self.coefficients[k])
              importantcircles[0][0] = 'x'
              importantcircles[1][0] = '1'
              st.fixLabels()
              tmpstates.append(st)
              tmpcoeffs.append(self.coefficients[k])
            elif (bottom=='1'):
              # Here a single circle passes through the crossing (the branch
              # above handles two), so there is no importantcircles[1];
              # indexing it used to raise IndexError. The term gets
              # coefficient 0 and is dropped by simplify().
              for circ in importantcircles:
                circ[0] = 'x'
              st.fixLabels()
              st.smoothing.saddleMove(crossing)
              tmpstates.append(st)
              tmpcoeffs.append(0)
            else:
              importantcircles[0][0] = 'x'
              st.fixLabels()
              st.smoothing.saddleMove(crossing)
              tmpstates.append(st)
              tmpcoeffs.append(self.coefficients[k])
          else:
            importantcircles[-1][0] = 'x'
            st.fixLabels()
            st.smoothing.saddleMove(crossing)
            tmpstates.append(st)
            if (bottom=='1'):
              tmpcoeffs.append(-1*self.coefficients[k])
            else:
              tmpcoeffs.append(0)
      #add the new states to result
      for i in range(len(tmpstates)):
        resstates.append(tmpstates[i])
        rescoeffs.append(tmpcoeffs[i])
    #store the result
    res = Ccelement(resstates,rescoeffs)
    res = res.simplify()
    self.states = res.states
    self.coefficients = res.coefficients


  def reidemeister2(self,crossing):
    """Apply a Reidemeister II move at crossings `crossing` and `crossing`+1
    (1-based among the surviving crossings) to every term, in place.

    For each state the eight nodes of the two crossings are located
    (crossnodes[0..3] for the first crossing, [4..7] for the second, by
    data%4), the paths through them are spliced out of the circles, and:

    * both braided (crossnodes[0].next == crossnodes[1] and
      crossnodes[4].next == crossnodes[5]): labels and coefficient are kept;
    * both non-braided (crossnodes[0].next == crossnodes[2] and
      crossnodes[4].next == crossnodes[6]): the circle through crossnodes[1]
      is the isolated one; unless it is labeled 'x' the term is dropped.
      Otherwise the remaining circles are re-traced:
        - one other circle, split in two: if the circle through crossnodes[0]
          was '1', the terms ('1','x') and ('x','1'), else ('x','x');
        - two other circles (through crossnodes[0] and crossnodes[5]) merged
          into one labeled 'x' for labels {'1','x'}, '0' for ('x','x'), and
          '1' otherwise;
      every such term gets coefficient -c;
    * mixed resolutions: the term is dropped.

    A zero element (no states) is left unchanged. The result is combined
    with simplify().
    """
    if len(self.states)==0:
      # nothing to transform (e.g. every term died in an earlier move);
      # self.states[0] below would raise IndexError
      return
    if self.states[0].smoothing.braid.word[crossing-1]!=-self.states[0].smoothing.braid.word[crossing]:
      print("invalid Reidemeister II move")
      return
    resstates = []
    rescoeffs = []
    for k in range(len(self.states)):
      st = self.states[k]
      tmpstates = []
      tmpcoeffs = []
      crossnodes = [None,None,None,None,None,None,None,None]
      crossnumbers = st.smoothing.checkCrossings()
      for nod in st.smoothing.graph.nodes:
        if (nod.data//4==crossnumbers[crossing-1] or nod.data//4==crossnumbers[crossing]):
          if (nod.data//4==crossnumbers[crossing-1]):
            i = nod.data%4
          else:
            i = (nod.data%4)+4
          crossnodes[i] = nod
      #note: importantcircles has either one, two, or three labeled circles,
      #in st.labeledcircles order (NOTE(review): original comment said the first
      #always contains the upper left node -- depends on circle order, confirm)
      importantcircles = []
      importantidx = [None,None,None,None,None,None,None,None]
      circno = 0
      for circ in st.labeledcircles:
        found = False
        for i in range(len(crossnodes)):
          if crossnodes[i] in circ[1]:
            if not circ in importantcircles:
              importantcircles.append(circ)
            importantidx[i] = [circno,circ[1].index(crossnodes[i])]
            found = True
        if found:
          circno += 1
      #the circles' own copies of crossnodes [[0,5],[2,7]]
      importantnodes = [[None,None],[None,None]]
      for i in range(2):
        for j in range(2):
          importantnodes[i][j] = importantcircles[importantidx[2*i+5*j][0]][1][importantidx[2*i+5*j][1]]
      importantnodes[0][0].prev.next = importantnodes[0][1].next
      importantnodes[0][1].next.prev = importantnodes[0][0].prev
      importantnodes[1][0].prev.next = importantnodes[1][1].next
      importantnodes[1][1].next.prev = importantnodes[1][0].prev
      tmp = copy.deepcopy(importantcircles)
      #case where both crossings have braided smoothing: labels of circles don't change
      if (crossnodes[0].next==crossnodes[1] and crossnodes[4].next==crossnodes[5]):
        for nod in crossnodes:
          for circ in importantcircles:
            if nod in circ[1]:
              circ[1].remove(nod)
        st.fixLabels()
        st.smoothing.reidemeister2(crossing)
        tmpstates.append(st)
        tmpcoeffs.append(self.coefficients[k])
      #case where both crossings have non-braided smoothing
      if (crossnodes[0].next==crossnodes[2] and crossnodes[4].next==crossnodes[6]):
        #if isolated circle is not labeled 'x', it dies (term dropped), otherwise splits into more cases
        #isolated circle contains crossnodes[1]
        if importantcircles[importantidx[1][0]][0]=='x':
          left = importantcircles[importantidx[0][0]][0]
          right = importantcircles[importantidx[5][0]][0]
          newcircles = []
          for nod in crossnodes:
            for circ in tmp:
              if nod in circ[1]:
                circ[1].remove(nod)
          #splitting one circle, isolated circle dies
          if len(tmp)==2:
            tmp = tmp[0][1] + tmp[1][1]
            for i in range(2):
              if len(tmp)==0:
                newcirc = []
                newcircles.append(newcirc)
              else:
                newcirc = [tmp[0]]
                tmp.pop(0)
                while True:
                  last = newcirc[-1]
                  if (not last.next in newcirc):
                    newcirc.append(last.next)
                    tmp.remove(last.next)
                  elif (not last.prev in newcirc):
                    newcirc.append(last.prev)
                    tmp.remove(last.prev)
                  else:
                    break
                newcircles.append(newcirc)
            if left=='1':
              newlabeledcircles = [['1',newcircles[0]],['x',newcircles[1]]]
              for circ in importantcircles:
                st.labeledcircles.remove(circ)
              for circ in newlabeledcircles:
                st.labeledcircles.append(circ)
              st.fixLabels()
              st.smoothing.reidemeister2(crossing)
              # snapshot the ('1','x') term before relabeling st as ('x','1')
              st1 = copy.deepcopy(st)
              tmpstates.append(st1)
              tmpcoeffs.append(-1*self.coefficients[k])
              newlabeledcircles2 = [['x',newcircles[0]],['1',newcircles[1]]]
              for circ in newlabeledcircles:
                st.labeledcircles.remove(circ)
              for circ in newlabeledcircles2:
                st.labeledcircles.append(circ)
              st.fixLabels()
              tmpstates.append(st)
              tmpcoeffs.append(-1*self.coefficients[k])
            else:
              newlabeledcircles = [['x',newcircles[0]],['x',newcircles[1]]]
              for circ in importantcircles:
                st.labeledcircles.remove(circ)
              for circ in newlabeledcircles:
                st.labeledcircles.append(circ)
              st.fixLabels()
              st.smoothing.reidemeister2(crossing)
              tmpstates.append(st)
              tmpcoeffs.append(-1*self.coefficients[k])
          #merging two distinct circles, isolated circle dies
          else:
            tmp = tmp[0][1]+tmp[1][1]+tmp[2][1]
            newcirc = [tmp[0]]
            tmp.pop(0)
            while True:
              last = newcirc[-1]
              if (not last.next in newcirc):
                newcirc.append(last.next)
                tmp.remove(last.next)
              elif (not last.prev in newcirc):
                newcirc.append(last.prev)
                tmp.remove(last.prev)
              else:
                break
            newcircles.append(newcirc)
            label = '1'
            if ((left=='x' and right=='1') or (left=='1' and right=='x')):
              label = 'x'
            elif (left=='x' and right=='x'):
              label = '0'
            newlabeledcircles = [[label,newcircles[0]]]
            for circ in importantcircles:
              st.labeledcircles.remove(circ)
            for circ in newlabeledcircles:
              st.labeledcircles.append(circ)
            st.fixLabels()
            st.smoothing.reidemeister2(crossing)
            tmpstates.append(st)
            tmpcoeffs.append(-1*self.coefficients[k])
      for i in range(len(tmpstates)):
        resstates.append(tmpstates[i])
        rescoeffs.append(tmpcoeffs[i])
    res = Ccelement(resstates,rescoeffs)
    res = res.simplify()
    self.states = res.states
    self.coefficients = res.coefficients
  
  #causes a death for all circles
  def deaths(self):
    res = 0
    for i in range(len(self.states)):
      x = 1
      for circ in self.states[i].labeledcircles:
        if circ[0]=='1':
          x = 0
          break
      res += x*self.coefficients[i]
    self.states = []
    self.coefficients = [res]

  #performs the sequence of simplifying cobordism moves given in movetypes and movecrossings
  #movetypes is a list of the type of cobordism move ('s' for saddle move, 'r' for Reidemeister II move, and 'd' for death)
  #movecrossings is a list of the corresponding crossing numbers; for death, any crossing number should work
  def cobordism(self,movetypes,movecrossings,verbose=False):
    """Apply a sequence of simplifying cobordism moves to self, in place.

    movetypes[i] is 's' (saddle move), 'r' (Reidemeister II move) or anything
    else for a death; movecrossings[i] is that move's crossing number in the
    ORIGINAL braid. Earlier saddles remove one crossing and earlier
    Reidemeister II moves remove two, so each crossing number is shifted
    down by the crossings already removed before it; deaths remove none,
    so their crossing number never matters. With verbose=True the element
    (and the braid of its first state) is printed after each move. Finally
    every state's smoothing is rebuilt with fixSmoothing().
    """
    for i, move in enumerate(movetypes):
      target = movecrossings[i]
      #count crossings removed by earlier moves to the left of this one
      before = 0
      for j in range(i):
        if movecrossings[j]<target:
          if movetypes[j]=='s':
            before += 1
          elif movetypes[j]=='r':
            before += 2
      if move=='s':
        self.saddleMove(target-before)
      elif move=='r':
        self.reidemeister2(target-before)
      else:
        self.deaths()
      if verbose:
        print(self)
        if (len(self.states)>0):
          self.states[0].smoothing.braid.printBraid()
        print('-------------------------------')
    for st in self.states:
      st.fixSmoothing()

  #makes self into a single labeled state of n 1-labeled circles with nodes labeled according to braid
  def birth(self,braid):
    """Replace self by a single state with coefficient 1 in which each braid
    row is its own circle labeled '1'.

    Nodes are created in the same order Smoothing.setCircles uses (two on
    the upper row, then two on the lower row, per crossing), starting from
    0 because constructing a Smoothing resets the node counter. Every
    crossing is given the resolution that keeps the rows ('0' for positive,
    '1' for negative crossings) and every row is closed into a cycle. The
    state stores a deep copy of `braid`, so later moves on it never modify
    the caller's braid.
    """
    unbraid = Braid(braid.n,[])
    unsmooth = Smoothing(unbraid,'')
    rows = [[] for _ in range(braid.n)]

    def append_node(row):
      # append a fresh node to `row`, linked after the row's current last node
      if row:
        row.append(Node(prev=row[-1]))
        row[-2].next = row[-1]
      else:
        row.append(Node())

    for cross in braid.word:
      upper = rows[braid.n-abs(cross)-1]
      lower = rows[braid.n-abs(cross)]
      append_node(upper)
      append_node(upper)
      append_node(lower)
      append_node(lower)
    for row in rows:
      if len(row)>0:
        row[0].prev = row[-1]
        row[-1].next = row[0]
    unsmooth.braid = copy.deepcopy(braid)
    unsmooth.smooth = ''.join('0' if cross>0 else '1' for cross in braid.word)
    graphnodes = [nod for row in rows for nod in row]
    unsmooth.graph = Graph(graphnodes)
    unsmooth.circles = rows
    st = LabeledState(unsmooth,'1'*braid.n)
    self.states = [st]
    self.coefficients = [1]

  #makes a Reidemeister I move and a saddle move to add a crossing
  #braid is the boundary braid of the cobordism, crossing is the number of the crossing
  def saddleMoveUp(self,braid,crossing):
    """Apply the map for a Reidemeister I move plus a saddle that adds crossing
    number `crossing` (1-based position in `braid.word`), in place.

    Circles are not re-traced; only labels and coefficients change. For each
    state, importantcircles[0] is the circle through a node of that crossing
    with data%4 <= 1, and importantcircles[1] is a circle (not already
    recorded) through one of its other nodes.

    Positive crossing, both circles found:
      '1','1'         -> ('x','1') with +c and ('1','x') with -c;
      only top '1'    -> top becomes 'x', +c;
      only bottom '1' -> bottom becomes 'x', -c;
      neither '1'     -> term dropped.
    Positive crossing, one of the two not found (e.g. one circle holds all
    nodes): term dropped. Negative crossing: term kept unchanged.
    The result is NOT passed through simplify().
    """
    resstates = []
    rescoeffs = []
    for i in range(len(self.states)):
      st = self.states[i]
      tmpstates = []
      tmpcoeffs = []
      # entries are st.labeledcircles' own [label, circle] lists, so
      # relabeling them below relabels st itself
      importantcircles = [None,None]
      for circ in st.labeledcircles:
        for nod in circ[1]:
          if (nod.data//4 == crossing-1):
            if ((nod.data%4 <= 1) and (circ not in importantcircles)):
              importantcircles[0] = circ
            elif (circ not in importantcircles):
              importantcircles[1] = circ
      #positive crossing case
      if braid.word[crossing-1]>0:
        #case where saddle is across same circle goes to 0
        if ((importantcircles[0] is not None) and (importantcircles[1] is not None)):  
          if (importantcircles[0][0]=='1' and importantcircles[1][0]=='1'):
            importantcircles[0][0] = 'x'
            # snapshot the ('x','1') term before relabeling st as ('1','x')
            st1 = copy.deepcopy(st)
            st1.fixLabels()
            tmpstates.append(st1)
            tmpcoeffs.append(self.coefficients[i])
            importantcircles[0][0] = '1'
            importantcircles[1][0] = 'x'
            st.fixLabels()
            tmpstates.append(st)
            tmpcoeffs.append(-1*self.coefficients[i])
          elif (importantcircles[0][0]=='1'):
            importantcircles[0][0] = 'x'
            st.fixLabels()
            tmpstates.append(st)
            tmpcoeffs.append(self.coefficients[i])
          elif (importantcircles[1][0]=='1'):
            importantcircles[1][0] = 'x'
            st.fixLabels()
            tmpstates.append(st)
            tmpcoeffs.append(-1*self.coefficients[i])
      #negative crossing case
      else:
        tmpstates.append(st)
        tmpcoeffs.append(self.coefficients[i])
      for j in range(len(tmpstates)):
        resstates.append(tmpstates[j])
        rescoeffs.append(tmpcoeffs[j])
    res = Ccelement(resstates,rescoeffs)
    #res = res.simplify()
    self.states = res.states
    self.coefficients = res.coefficients



  #makes a Reidemeister II move to add two crossings
  #braid is the boundary braid of the cobordism, lcross is the number of the left crossing, rcross is the number of the right crossing
  def reidemeister2Up(self,braid,lcross,rcross):
    """Add the crossings lcross and rcross by a Reidemeister II move.

    For every state st (coefficient c), the original st is always kept with
    coefficient c. Extra states are built from a deep copy st2, in which the
    node links at both crossings are rewired:

      * The crossing nodes lie on two distinct circles (importantcircles[0]
        and [1] both found). Their copied nodes are regrouped into two new
        circles. newcircles[1] gets label '1'. newcircles[0] gets '1' when
        both old labels were '1', or 'x' when exactly one was '1'. If both
        old labels were 'x', no extra state is added.
      * Otherwise (one circle). Its copied nodes are regrouped into three
        circles, and the one containing crossnodes[1] gets label '1'. If the
        old label was '1', two extra states are added, with the other two
        circles labeled ('1','x') and ('x','1'). Otherwise one extra state
        is added with ('x','x').

    All extra states carry coefficient c. self.states / self.coefficients are
    replaced by the result after it passes through the Ccelement constructor.

    braid: boundary braid of the cobordism. It is not read in this method;
      presumably it is kept so the signature matches saddleMoveUp.
    lcross, rcross: 1-based numbers of the left and right crossings.
    """
    #TODO: the author marked this method as unfinished -- verify the cases below
    resstates = []
    rescoeffs = []
    for i in range(len(self.states)):
      st = self.states[i]
      #always keep the original state itself (same object, not a copy) with its coefficient
      tmpstates = [st]
      tmpcoeffs = [self.coefficients[i]]
      #importantcircles[0]: circle through a crossing node with data%4 <= 1;
      #importantcircles[1]: circle through any other crossing node (None if not found)
      importantcircles = [None,None]
      #crossnodes[0..3]: nodes of lcross indexed by data%4; crossnodes[4..7]: nodes of rcross
      crossnodes = [None]*8
      #all rewiring happens on a deep copy so st itself stays untouched
      st2 = copy.deepcopy(st)
      for circ in st2.labeledcircles:
        for nod in circ[1]:
          if ((nod.data//4 == lcross-1) or (nod.data//4 == rcross-1)):
            if (nod.data//4 == lcross-1):
              crossnodes[nod.data%4] = nod
            else:
              crossnodes[nod.data%4 + 4] = nod
            if ((nod.data%4 <= 1) and (circ not in importantcircles)):
              importantcircles[0] = circ
            elif (circ not in importantcircles):
              importantcircles[1] = circ
      #rewire the copied links at lcross: 0<->2 via next, 1<->3 via prev
      #(and likewise 4<->6, 5<->7 at rcross)
      #NOTE(review): assumes all 8 crossing nodes were found; a missing one is None and raises AttributeError here
      crossnodes[0].next = crossnodes[2]
      crossnodes[2].next = crossnodes[0]
      crossnodes[1].prev = crossnodes[3]
      crossnodes[3].prev = crossnodes[1]
      crossnodes[4].next = crossnodes[6]
      crossnodes[6].next = crossnodes[4]
      crossnodes[5].prev = crossnodes[7]
      crossnodes[7].prev = crossnodes[5]
      #case where the crossing nodes lie on two distinct circles
      if ((importantcircles[0] is not None) and (importantcircles[1] is not None)):
        top = importantcircles[0][0]
        bot = importantcircles[1][0]
        #NOTE(review): deepcopy creates new Node objects (carrying the rewired links), so the
        #new circles below are built from copies rather than from st2's own nodes -- confirm intended
        tmp = copy.deepcopy(importantcircles)
        tmp = tmp[0][1]+tmp[1][1]
        newcircles = []
        #newcircles[0] is the outer circle, newcircles[1] is the inner circle
        #NOTE(review): this relies on tmp[0] (first node of importantcircles[0]) lying on the outer circle; not checked here
        for j in range(2):
          #grow a circle from the first unused node along next/prev links, removing
          #visited nodes from tmp, until both neighbours of the last node are already in it
          newcirc = [tmp[0]]
          tmp.pop(0)
          while True:
            last = newcirc[-1]
            if (not last.next in newcirc):
              newcirc.append(last.next)
              tmp.remove(last.next)
            elif (not last.prev in newcirc):
              newcirc.append(last.prev)
              tmp.remove(last.prev)
            else:
              break
          newcircles.append(newcirc)
        newlabeledcircles = [['1',newcircles[1]]]
        if (top=='1' and bot=='1'):
          newlabeledcircles.append(['1',newcircles[0]])
          for circ in importantcircles:
            st2.labeledcircles.remove(circ)
          for circ in newlabeledcircles:
            st2.labeledcircles.append(circ)
          st2.fixLabels()
          tmpstates.append(st2)
          tmpcoeffs.append(self.coefficients[i])
        elif (top=='1' or bot=='1'):
          newlabeledcircles.append(['x',newcircles[0]])
          for circ in importantcircles:
            st2.labeledcircles.remove(circ)
          for circ in newlabeledcircles:
            st2.labeledcircles.append(circ)
          st2.fixLabels()
          tmpstates.append(st2)
          tmpcoeffs.append(self.coefficients[i])
        #both labels 'x': no extra state
      #case where the crossing nodes lie on the same circle
      #NOTE(review): also reached when neither circle was found; importantcircles[1][0] then raises TypeError
      else:
        if importantcircles[0] is not None:
          label = importantcircles[0][0]
          tmp = copy.deepcopy(importantcircles[0][1])
        else:
          label = importantcircles[1][0]
          tmp = copy.deepcopy(importantcircles[1][1])
        newcircles = []
        #index of the isolated circle
        circno = 0
        #the single circle is expected to split into three after rewiring
        for j in range(3):
          newcirc = [tmp[0]]
          tmp.pop(0)
          while True:
            last = newcirc[-1]
            #NOTE(review): `last` is a deep-copied node while crossnodes[1] belongs to st2; this only
            #matches if Node defines a value-based __eq__ -- otherwise circno stays 0. Confirm.
            if last==crossnodes[1]:
              circno = j
            if (not last.next in newcirc):
              newcirc.append(last.next)
              tmp.remove(last.next)
            elif (not last.prev in newcirc):
              newcirc.append(last.prev)
              tmp.remove(last.prev)
            else:
              break
          newcircles.append(newcirc)
        #the circle through crossnodes[1] always gets label '1'
        newlabeledcircles = [['1',newcircles[circno]]]
        if label=='1':
          #circno-1 / circno-2 (negative indices wrap) select the other two circles
          newlabeledcircles += [['1',newcircles[circno-1]],['x',newcircles[circno-2]]]
          for circ in importantcircles:
            if (circ is not None):
              st2.labeledcircles.remove(circ)
          for circ in newlabeledcircles:
            st2.labeledcircles.append(circ)
          st2.fixLabels()
          st1 = copy.deepcopy(st2)
          tmpstates.append(st1)
          tmpcoeffs.append(self.coefficients[i])
          #second extra state: swap the labels of the other two circles (same coefficient)
          newlabeledcircles[1][0] = 'x'
          newlabeledcircles[2][0] = '1'
          st2.fixLabels()
          tmpstates.append(st2)
          tmpcoeffs.append(self.coefficients[i])
        else:
          newlabeledcircles += [['x',newcircles[circno-1]],['x',newcircles[circno-2]]]
          for circ in importantcircles:
            if (circ is not None):
              st2.labeledcircles.remove(circ)
          for circ in newlabeledcircles:
            st2.labeledcircles.append(circ)
          st2.fixLabels()
          tmpstates.append(st2)
          tmpcoeffs.append(self.coefficients[i])
      for j in range(len(tmpstates)):
        resstates.append(tmpstates[j])
        rescoeffs.append(tmpcoeffs[j])
    #pass the collected terms through the Ccelement constructor before storing them
    res = Ccelement(resstates,rescoeffs)
    #res = res.simplify()
    self.states = res.states
    self.coefficients = res.coefficients

  #intended to return the result of applying the differential map to self (unfinished: currently always returns None)
  def differential(self):
    """Apply the differential map to self -- NOT IMPLEMENTED.

    TODO: unfinished stub. It only begins iterating over self.states and
    always returns None.
    """
    for _state in self.states:
      break
    return None

#calculates the relative Khovanov-Jacobsson class for the cobordism given by movetypes & movecrossings to a particular braided knot "braid"
#movetypes is a list of moves; "s" represents a saddle move to add a crossing, "r" represents a Reidemeister II move to add two crossings
#movecrossings[i] is the 1-based crossing number for an "s" move, or a [left, right] pair of crossing numbers for an "r" move
def kjclass(braid,movetypes,movecrossings):
  """Compute the relative Khovanov-Jacobsson class of the braided knot `braid`.

  Starts from the birth element of `braid` and applies the moves in order:
    's' -> saddleMoveUp(braid, movecrossings[i])
    'r' -> reidemeister2Up(braid, movecrossings[i][0], movecrossings[i][1])
  Any other move type is skipped. Afterwards fixSmoothing() is called on
  every state, and the simplified element is returned.
  """
  element = Ccelement([],[])
  element.birth(braid)
  for i, move in enumerate(movetypes):
    if move == 's':
      element.saddleMoveUp(braid,movecrossings[i])
    elif move == 'r':
      left, right = movecrossings[i][0], movecrossings[i][1]
      element.reidemeister2Up(braid,left,right)
  for state in element.states:
    state.fixSmoothing()
  return element.simplify()



  


#examples / demos (example8 is invoked at the bottom of this file)
def example1():
  """Demo: KJ-class of a 2-strand braid; each state is then sent through the mirror cobordism."""
  braid = Braid(2,[1,1,-1,1,-1])
  kj = kjclass(braid,['s','r','r'],[1,[2,3],[4,5]])
  print(kj)
  for state in kj.states:
    single = Ccelement([state],[1])
    print(single)
    single.mirror()
    single.cobordism(['r','r','s','d'],[1,3,5,0])
    print(single)

def example2():
  """Demo: try every pair of leading crossings on a 3-strand braid and print the resulting element."""
  choices = [-2,-1,1,2]
  for first in choices:
    for second in choices:
      braid = Braid(3,[first,second,1,-1,2,-2])
      element = kjclass(braid,['s','s','r','r'],[1,2,[3,4],[5,6]])
      element.mirror()
      element.cobordism(['r','r','s','s','d'],[1,3,5,6,0])
      print(first, second, element)

def example3():
  """Demo: two KJ-class computations on 3-strand braids, each pushed through a mirror cobordism."""
  first_braid = Braid(3,[2,-2,1,2])
  first = kjclass(first_braid,['r','s','s'],[[1,2],3,4])
  first.mirror()
  first.cobordism(['s','s','r','d'],[1,2,3,0],verbose=True)
  print(first)

  second_braid = Braid(3,[2,-2,1,2,-1,1])
  second = kjclass(second_braid,['r','s','s','r'],[[1,2],3,4,[5,6]])
  second.cobordism(['r'],[5])
  second.mirror()
  second.cobordism(['s','s','r','d'],[1,2,3,0])
  print(second)

def example4():
  """Demo: KJ-class of an 11-crossing 4-strand braid, its mirror, and a verbose cobordism run."""
  braid = Braid(4,[-3,-2,3,3,1,-2,-1,-3,2,-1,-2])
  braid.printBraid()
  kj = kjclass(braid,['r','r','r','r','s','s','s'],[[1,3],[9,11],[4,8],[5,7],2,6,10])
  # Disabled alternative: push each state through the mirror cobordism separately.
  # for state in kj.states:
  #   single = Ccelement([state],[1])
  #   single.mirror()
  #   single.cobordism(['s','s','s','r','r','r','r','d'],[2,6,10,5,4,1,9,0])
  #   print(single)
  print("KJ-class: ", kj)
  kj.mirror()
  print("mirror: ", kj)
  kj.cobordism(['s','s','s','r','r','r','r','d'],[2,6,10,5,4,1,9,0],verbose=True)
  # Disabled alternative: kj.cobordism(['s','s','r','r'],[4,5,3,2],verbose=True)

def example5():
  """Demo: hand-build a four-term element on braid b9_46 and push it through a cobordism."""
  b9_46 = Braid(4,[2,1,-2,3,1,2,-1,-3,-3,2,3])
  b9_46.printBraid()
  smoothings = ['00111000100','00101000101','01101000100','01001001100']
  # Each Smoothing is built immediately before its LabeledState, as in the original sequence.
  states = [LabeledState(Smoothing(b9_46,code),'xxxx') for code in smoothings]
  element = Ccelement(states,[1]*len(states))
  print(element)
  element.cobordism(['s','s','s','r','r','r','r','d'],[2,6,10,5,4,1,9,0],verbose=True)
  print(element)

def example6():
  """Demo: KJ-class of braid b10_140, then its mirror pushed through a cobordism."""
  b10_140 = Braid(4,[-3,-2,3,-2,1,2,2,2,-3,-2,-2,-2,-1])
  moves = ['r','r','r','r','r','s','s','s']
  where = [[1,3],[5,13],[6,12],[7,11],[8,10],2,4,9]
  kj = kjclass(b10_140,moves,where)
  print(kj)
  kj.mirror()
  kj.cobordism(['s','s','s','r','r','r','r','r','d'],[5,10,12,4,3,2,1,11,0])
  print(kj)

def example7():
  """Demo: print a single hand-built labeled state on a 3-strand braid."""
  braid = Braid(3,[2,1,-1,2])
  state = LabeledState(Smoothing(braid,'1001'),'1x')
  print(Ccelement([state],[1]))

def example8():
  """Demo: print the simplified KJ-class of a 4-crossing 3-strand braid."""
  braid = Braid(3,[2,1,-1,2])
  kj = kjclass(braid,['r','s','s'],[[2,3],1,4])
  print(kj.simplify())

if __name__ == "__main__":
  # Run the demo only when executed as a script, not as a side effect of importing this module.
  example8()
